#include <torch/torch.h>
#include <iostream>
#include <cmath>
#include <cstdio>
#include <vector>

// self defined
#include "grain_growth.hpp"

auto options =
    torch::TensorOptions()
        .dtype(torch::kFloat64)
        .layout(torch::kStrided)
        .device(torch::kCPU)
        .requires_grad(true);

struct LaplacianOpImpl : torch::nn::Module {
    torch::Tensor conv_kernel;
    LaplacianOpImpl() {
        double kernel[1][1][3][3] = {{{{0.0, 1.0, 0.0}, {1.0, -4.0, 1.0}, {0.0, 1.0, 0.0}}}};
        torch::Tensor _kernel = torch::zeros({1, 1, 3, 3}, torch::dtype(torch::kFloat64).requires_grad(false));
        // _kernel.unsqueeze_(0).unsqueeze_(0); // 1 1 3 3
        for (int i = 0; i < 3; ++i) {
            for (int j = 0; j < 3; ++j) {
                _kernel.index_put_({0, 0, i, j}, kernel[0][0][i][j]);
            }
        }
        conv_kernel = register_parameter("conv_kernel", _kernel);
    }

    torch::Tensor forward(torch::Tensor input, double dx=1.0, double dy=1.0) {
        bool unsqueezed = false;
        if(input.dim() == 2) {
            input.unsqueeze_(0);
            unsqueezed = true;
            // std::cout << "=x" << std::endl;
        }
        using namespace torch::indexing;
        // std::cout << "x" << std::endl;
        torch::Tensor input1 = torch::cat({input.index({Slice(None, None, None), Slice(-1, None, None), Slice(None, None, None)}),
                                           input, 
                                           input.index({Slice(None, None, None), Slice(None, 1, None), Slice(None, None, None)})},
                                          1).clone();
        // std::cout << "xx" << std::endl;
        torch::Tensor input2 = torch::cat({input1.index({Slice(None, None, None), Slice(None, None, None), Slice(-1, None, None)}),
                                           input1, 
                                           input1.index({Slice(None, None, None), Slice(None, None, None), Slice(None, 1, None)})},
                                          2).clone();
        // std::cout << "xxx" << std::endl;
        torch::Tensor conv_input = input2.unsqueeze_(1);
        // std::cout << "xxxx" << std::endl;
        // std::cout << "conv_input sizes: " << conv_input.sizes() << std::endl;
        // std::cout << "conv_kernel sizes: " << conv_kernel.clone().sizes() << std::endl;
        torch::Tensor result = torch::nn::functional::conv2d(input2, conv_kernel.clone(), torch::nn::functional::Conv2dFuncOptions().stride(1)).clone() / (dx * dy);
        // std::cout << "xxxxx" << std::endl;
        if(unsqueezed) {
            result.squeeze_(0);
            // std::cout << "x=" << std::endl;
        }
        // std::cout << "xxxxxx" << std::endl;
        return result;
    }
};

TORCH_MODULE_IMPL(LaplacianOp, LaplacianOpImpl);

struct ggTimeStepImpl : torch::nn::Module {
    torch::Tensor L;
    torch::Tensor A;
    torch::Tensor B;
    torch::Tensor kappa;
    valueType dt, dx, dy, eps;
    int _N;
    LaplacianOp lapop;

    int Nx, Ny;
    uint n_grains;
    valueType h, h2;

    ggTimeStepImpl(double _dt, double _dx, double _dy, double _eps, int __N, valueType _L, valueType _A, valueType _B, valueType _kappa, int _Nx, int _Ny, uint _n_grains, valueType _h) : lapop() {
        
        torch::Tensor L_value = torch::from_blob(&_L, {1,}, torch::dtype(torch::kFloat64)).clone();
        L_value.index_put_({0,}, _L);
        L = register_parameter("L", L_value.clone().requires_grad_(true));      

        torch::Tensor A_value = torch::from_blob(&_A, {1,}, torch::dtype(torch::kFloat64)).clone();
        A_value.index_put_({0,}, _A);
        A = register_parameter("A", A_value.clone().requires_grad_(true));  

        torch::Tensor B_value = torch::from_blob(&_B, {1,}, torch::dtype(torch::kFloat64)).clone();
        B_value.index_put_({0,}, _B);
        B = register_parameter("B", B_value.clone().requires_grad_(true));  

        torch::Tensor kappa_value = torch::from_blob(&_kappa, {1,}, torch::dtype(torch::kFloat64)).clone();
        kappa_value.index_put_({0,}, _kappa);
        kappa = register_parameter("kappa", kappa_value.clone().requires_grad_(true));  

        dt = _dt;
        dx = _dx;
        dy = _dy;
        eps = _eps;
        _N = __N;
        Nx = _Nx;
        Ny = _Ny;
        n_grains = _n_grains;
        h = _h;
        h2 = _h*_h;
    }

    torch::Tensor lap(torch::Tensor mat, double dx=1.0, double dy=1.0) {
        return mat * dx * dy;
        // return lapop->forward(mat, dx, dy);
    }

    torch::Tensor fix_deviations(torch::Tensor mat, double lb=0.0, double ub=1.0) {
        torch::Tensor mat_mu = mat.masked_fill(torch::ge(mat, ub).detach(), ub).clone();
        torch::Tensor mat_mul = mat.masked_fill(torch::le(mat, lb).detach(), lb).clone();
        return mat_mul;
    }

    torch::Tensor forward(torch::Tensor eta_1, torch::Tensor eta_2) { 
        // torch::Tensor sum_eta_sqr = torch::mul(eta_1, eta_1) + torch::mul(eta_2, eta_2);

        // eta_1 update
        // std::cout << "1" << std::endl;
        torch::Tensor d_energy_1 = -A*eta_1 + B*torch::pow(eta_1, 3) + 2*eta_1*(torch::pow(eta_2, 2));
        // std::cout << "2" << std::endl;
        torch::Tensor lap_eta_1 = lap(eta_1, dx, dy);
        // std::cout << "3" << std::endl;
        torch::Tensor eta_1_new = eta_1 - dt*L * (d_energy_1 - kappa*lap_eta_1);
        // std::cout << "4" << std::endl;

        eta_1_new = fix_deviations(eta_1_new);
        // std::cout << "5" << std::endl;

        // eta_2 update
        torch::Tensor d_energy_2 = -A*eta_2 + B*torch::pow(eta_2, 3) + 2*eta_2*(torch::pow(eta_1, 2));
        // std::cout << "6" << std::endl;
        torch::Tensor lap_eta_2 = lap(eta_2, dx, dy);
        // std::cout << "7" << std::endl;
        torch::Tensor eta_2_new = eta_2 - dt*L * (d_energy_2 - kappa*lap_eta_2);
        // std::cout << "8" << std::endl;

        eta_2_new = fix_deviations(eta_2_new);
        // std::cout << "9" << std::endl;

        eta_1_new.unsqueeze_(0);
        eta_2_new.unsqueeze_(0);
        torch::Tensor concate_vals = torch::cat({eta_1_new, eta_2_new});
        // std::cout << "10" << std::endl;

        return concate_vals;
    }
};

TORCH_MODULE_IMPL(ggTimeStep, ggTimeStepImpl);
